# -*- coding: utf-8 -*-
import numpy as np
import cv2
from PIL import Image
import torch
from torch.utils.data import Dataset
from utils.utils import preprocess_input, cvtColor



class DataGenerator(Dataset):
    """Classification dataset built from ``"<image_path> <label>"`` annotation lines.

    Each line is whitespace-split: token 0 is the image path and token 1 is
    parsed with ``int`` as the class label. Images are letterboxed
    (``random=False``) or randomly augmented (``random=True``) to
    ``input_shape`` and returned channel-first.

    Args:
        annotation_lines: sequence of annotation strings, one sample each.
        input_shape: target ``(height, width)``.
        random: apply random augmentation when True; deterministic
            letterbox resize only when False.
    """

    def __init__(self, annotation_lines, input_shape, random=True):
        self.annotation_lines = annotation_lines
        self.input_shape = input_shape
        self.random = random

    def __len__(self):
        # One sample per annotation line.
        return len(self.annotation_lines)

    def __getitem__(self, index):
        """Load sample *index* and return ``(image, label)``.

        ``image`` is the ``(h, w, 3)`` array from ``get_random_data`` cast to
        float32, passed through ``preprocess_input`` (assumed to preserve
        shape -- TODO confirm) and transposed to channel-first order;
        ``label`` is the integer parsed from the annotation line.
        """
        annotation_data = self.annotation_lines[index].rstrip().split()
        # Open the file as a context manager so its handle is released as soon
        # as the pixels have been consumed, instead of whenever GC runs.
        with Image.open(annotation_data[0]) as raw_image:
            image = self.get_random_data(raw_image, self.input_shape, random=self.random)
        image = np.transpose(preprocess_input(np.array(image).astype(np.float32)), [2, 0, 1])

        y = int(annotation_data[1])
        return image, y

    def rand(self, a=0, b=1):
        """Return a uniform float in ``[a, b)`` (``(b, a]`` if b < a) from NumPy's global RNG."""
        return np.random.rand() * (b - a) + a

    def get_random_data(self, image, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5, random=True):
        """Resize *image* to *input_shape*, optionally with random augmentation.

        Args:
            image: PIL image; normalised to RGB via ``cvtColor`` first.
            input_shape: target ``(h, w)``.
            jitter: aspect-ratio jitter; each of the two ratio factors is
                drawn from ``[1 - jitter, 1 + jitter)``.
            hue: half-range of the hue draw (currently has no effect, see the
                NOTE(review) in the body).
            sat: saturation scale bound; the factor is drawn from
                ``[1, sat)`` or is the reciprocal of such a draw, each with
                probability 0.5.
            val: value (brightness) scale bound, drawn like ``sat``.
            random: when False, only a letterbox resize (aspect-preserving,
                centred on a gray canvas) is performed.

        Returns:
            ``(h, w, 3)`` array on a 0-255 scale. The letterbox path returns
            float32. The augmented path returns float32, or uint8 when
            Gaussian noise was applied.
        """
        # ------------------------------#
        #   Read the image and convert it to RGB
        # ------------------------------#
        image = cvtColor(image)
        # ------------------------------#
        #   Source width/height and target height/width
        # ------------------------------#
        iw, ih = image.size
        h, w = input_shape

        if not random:
            # Largest scale at which the whole image fits inside (w, h).
            scale = min(w / iw, h / ih)
            nw = int(iw * scale)
            nh = int(ih * scale)
            dx = (w - nw) // 2
            dy = (h - nh) // 2

            # ---------------------------------#
            #   Pad the unused area with gray bars
            # ---------------------------------#
            image = image.resize((nw, nh), Image.BICUBIC)
            new_image = Image.new('RGB', (w, h), (128, 128, 128))
            new_image.paste(image, (dx, dy))
            image_data = np.array(new_image, np.float32)

            return image_data

        # ------------------------------------------#
        #   Random rescale with aspect-ratio distortion
        # ------------------------------------------#
        new_ar = w / h * self.rand(1 - jitter, 1 + jitter) / self.rand(1 - jitter, 1 + jitter)
        scale = self.rand(.75, 1.25)
        if new_ar < 1:
            nh = int(scale * h)
            nw = int(nh * new_ar)
        else:
            nw = int(scale * w)
            nh = int(nw / new_ar)
        image = image.resize((nw, nh), Image.BICUBIC)

        # ------------------------------------------#
        #   Paste at a random offset on a gray canvas. When the resized
        #   image is larger than the target, the offset is negative and
        #   the paste crops it.
        # ------------------------------------------#
        dx = int(self.rand(0, w - nw))
        dy = int(self.rand(0, h - nh))
        new_image = Image.new('RGB', (w, h), (128, 128, 128))
        new_image.paste(image, (dx, dy))
        image = new_image

        # ------------------------------------------#
        #   Random horizontal flip
        # ------------------------------------------#
        flip = self.rand() < .5
        if flip:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        # ------------------------------------------#
        #   Random rotation about the canvas centre
        # ------------------------------------------#
        rotate = self.rand() < .5
        if rotate:
            # randint's upper bound is exclusive: 16 gives the symmetric
            # [-15, 15] degree range instead of [-15, 14].
            angle = np.random.randint(-15, 16)
            a, b = w / 2, h / 2
            M = cv2.getRotationMatrix2D((a, b), angle, 1)
            # Uncovered corners are filled with the same gray as the padding.
            image = cv2.warpAffine(np.array(image), M, (w, h), borderValue=[128, 128, 128])

        # ------------------------------------------#
        #   Colour-space (HSV) distortion
        # ------------------------------------------#
        hue = self.rand(-hue, hue)
        # NOTE(review): the hue draw above is never applied to x[..., 0], so
        # the ``hue`` argument currently has no effect (the draw still advances
        # the RNG). Confirm whether hue jitter is intended before adding it.
        sat = self.rand(1, sat) if self.rand() < .5 else 1 / self.rand(1, sat)
        val = self.rand(1, val) if self.rand() < .5 else 1 / self.rand(1, val)
        x = cv2.cvtColor(np.array(image, np.float32) / 255, cv2.COLOR_RGB2HSV)
        x[..., 1] *= sat
        x[..., 2] *= val
        # Clamp H to at most 360 and S/V to [0, 1] before converting back.
        x[x[:, :, 0] > 360, 0] = 360
        x[:, :, 1:][x[:, :, 1:] > 1] = 1
        x[x < 0] = 0
        image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB) * 255

        # ------------------------------------------#
        #   Add random noise
        # ------------------------------------------#
        # Salt-and-pepper noise
        sp_n = self.rand() < .4
        if sp_n:
            prob = np.random.uniform(0, 0.005)
            image_data = sp_noise(image_data, prob)

        # Gaussian noise (returns uint8)
        ga_n = self.rand() < .4
        if ga_n:
            mean = np.random.uniform(0, 0.02)
            var = np.random.uniform(0, 0.002)
            image_data = gaussian_noise(image_data, mean, var)

        return image_data

def sp_noise(image, prob):
    '''
    Add salt-and-pepper noise.

    Each pixel (all channels together) is independently set to 0 when its
    uniform draw is below ``prob``, otherwise to 255 when the draw is above
    ``1 - prob``, and is otherwise left unchanged.

    image: ndarray whose first two axes are (rows, cols).
    prob: noise ratio; for prob <= 0.5 this is the per-pixel probability of
        each of pepper (0) and salt (255).
    Returns a modified copy with the same dtype; the input is not changed.
    '''
    output = image.copy()
    threshold = 1 - prob

    # One uniform draw per pixel, produced in the same row-major order as a
    # pixel-by-pixel loop of np.random.rand() calls, but in a single C call.
    rdm = np.random.rand(*image.shape[:2])

    # Salt first, then pepper: a pixel matching both tests (possible only when
    # prob > 0.5) ends up 0, matching the original if/elif precedence.
    output[rdm > threshold] = 255
    output[rdm < prob] = 0

    return output

def gaussian_noise(image, mean, var):
    '''
    Add Gaussian noise to an image whose values are on a 0-255 scale.

    mean : mean of the noise, on the normalised [0, 1] intensity scale
    var : variance of the noise, on the same normalised scale
    Returns a uint8 array with the same shape as ``image``.
    '''
    image = np.array(image / 255, dtype=float)
    noise = np.random.normal(mean, var ** 0.5, image.shape)
    # Always clip to the valid [0, 1] range. Negative values cannot survive
    # the uint8 cast below (they wrap around or are undefined), which would
    # turn dark pixels into bright speckles.
    out = np.clip(image + noise, 0.0, 1.0)
    return np.uint8(out * 255)

def detection_collate(batch):
    """Collate ``(image, label)`` samples into batched tensors.

    batch: iterable of ``(image, y)`` pairs (presumably from
        ``DataGenerator.__getitem__``), where each image is a same-shape
        ndarray and each ``y`` is an integer label.
    Returns ``(images, targets)``: a float32 tensor stacking the images along
    a new leading axis, and an int64 tensor of labels.
    """
    images = []
    targets = []
    for image, y in batch:
        images.append(image)
        targets.append(y)
    images = torch.from_numpy(np.array(images)).float()
    # Build the integer label tensor directly; round-tripping through float32
    # is redundant and silently rounds labels larger than 2**24.
    targets = torch.from_numpy(np.array(targets)).long()
    return images, targets
